RBONN: Recurrent Bilinear Optimization for a Binary Neural Network
3.8.1 Bilinear Model of BNNs
We formulate the optimization of BNNs as follows.
$$\arg\min_{w, A}\; \mathcal{L}_S(w, A) + \lambda G(w, A), \tag{3.122}$$
where $\lambda$ is a hyper-parameter and $G$ contains the bilinear part mentioned in Eq. 6.34. $w$ and $A$ form a pair of coupled variables. Thus, the conventional gradient descent method can be used to solve the bilinear optimization problem as
$$A^{t+1} = \left|A^t - \eta_1 \frac{\partial \mathcal{L}}{\partial A^t}\right|, \tag{3.123}$$

$$\begin{aligned}
\left(\frac{\partial \mathcal{L}}{\partial A^t}\right)^T &= \left(\frac{\partial \mathcal{L}_S}{\partial A^t}\right)^T + \lambda\left(\frac{\partial G}{\partial A^t}\right)^T \\
&= \left(\frac{\partial \mathcal{L}_S}{\partial a^t_{out}}\,\frac{\partial a^t_{out}}{\partial A^t}\right)^T + \lambda\, w^t\left(A^t w^t - b_{w^t}\right)^T \\
&= \left(\frac{\partial \mathcal{L}_S}{\partial a^t_{out}}\right)^T \left(b_{a^t_{in}} \odot b_{w^t}\right)\left(A^t\right)^{-2} + \lambda\, w^t \hat{G}(w^t, A^t),
\end{aligned} \tag{3.124}$$
where $\eta_1$ is the learning rate and $\hat{G}(w^t, A^t) = (A^t w^t - b_{w^t})^T$. The conventional gradient descent algorithm for bilinear models iteratively optimizes one variable while keeping the other fixed. This is suboptimal because it ignores the relationship between the two hidden variables during optimization. For example, when $w$ approaches zero due to the sparsity regularization term $R(w)$, $A$ takes on a larger magnitude due to $G$ (Eq. 6.34). Consequently, both the first and second terms of Eq. 3.124 are dramatically suppressed, causing a vanishing gradient problem for $A$. Conversely, if $A$ changes little during optimization, $w$ also suffers from a vanishing gradient under the supervision of $G$, leading to a local minimum. Because of the coupling between $w$ and $A$, the gradient calculation for $w$ is challenging.
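To make this suppression concrete, the following minimal NumPy sketch assumes a diagonal per-channel scale $A$, binary weights $b_w = \mathrm{sign}(w)$, and the quadratic form $G(w, A) = \frac{1}{2}\|A w - b_w\|_F^2$, chosen only so that its gradient reproduces the $w^t(A^t w^t - b_{w^t})^T$ structure in Eq. 3.124; these concrete choices are illustrative assumptions, not the book's exact definitions.

```python
import numpy as np

def grad_G_wrt_A(w, a):
    """Gradient of the assumed G(w, A) = 0.5 * ||A w - b_w||_F^2 w.r.t. the
    diagonal entries A_{i,i}.

    w : (I, J) latent real-valued weights, one row per output channel.
    a : (I,)   diagonal of A, one scale per output channel.
    """
    b_w = np.sign(w)
    residual = a[:, None] * w - b_w          # A w - b_w, shape (I, J)
    return np.sum(residual * w, axis=1)      # row-wise dot with w, shape (I,)

rng = np.random.default_rng(0)
I, J = 4, 16
a = np.ones(I)

for scale in [1.0, 1e-1, 1e-2, 1e-3]:        # w driven toward zero by sparsity
    w = scale * rng.standard_normal((I, J))
    g = grad_G_wrt_A(w, a)
    print(f"|w| ~ {scale:.0e}  ->  |dG/dA| ~ {np.abs(g).mean():.2e}")
```

As $w$ is driven toward zero, the printed gradient magnitude shrinks with it, which is exactly the stalling of the update in Eq. 3.123 described above.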
3.8.2 Recurrent Bilinear Optimization
We solve the problem in Eq. 6.34 from a new perspective, treating $w$ and $A$ as coupled variables. We aim to prevent $A$ from becoming denser and $w$ from becoming sparser, as analyzed above. First, based on the chain rule and the notation of [187], the scalar form of the update rule for $w_{i,j}$ is
$$\begin{aligned}
\hat{w}^{t+1}_{i,j} &= w^t_{i,j} - \eta_2 \frac{\partial \mathcal{L}_S}{\partial w^t_{i,j}} - \eta_2 \lambda\left(\frac{\partial G}{\partial w^t_{i,j}} + \mathrm{Tr}\!\left(\left(\frac{\partial G}{\partial A^t}\right)^T \frac{\partial A^t}{\partial w^t_{i,j}}\right)\right) \\
&= w^{t+1}_{i,j} - \eta_2 \lambda\, \mathrm{Tr}\!\left(w^t \hat{G}(w^t, A^t)\, \frac{\partial A^t}{\partial w^t_{i,j}}\right),
\end{aligned} \tag{3.125}$$
which is based on $w^{t+1}_{i,j} = w^t_{i,j} - \eta_2 \frac{\partial \mathcal{L}}{\partial w^t_{i,j}}$. Here $\hat{w}^{t+1}$ denotes $w$ at the $(t+1)$-th iteration when the coupling of $w$ and $A$ is considered. When computing the gradient of the coupled variable $w$, the gradient of its coupled variable $A$ must also be taken into account via the chain rule. The vanilla $w^{t+1}$ denotes $w$ computed at the $(t+1)$-th iteration without considering the coupling relationship. For simplicity, we denote $I = C_{out}$ and $J = C_{in} \times K \times K$. Writing $w$ in row-vector form $[w_1, \cdots, w_I]^T$ and $\hat{G}$ in column-vector form $[\hat{g}_1, \cdots, \hat{g}_I]$, with $i = 1, \cdots, I$ and $j = 1, \cdots, J$, we can see that $A_{i,i}$ and $w_{n,j}$ are independent for all $n \neq i$.
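This independence collapses the trace in Eq. 3.125 to a single entry, $[w^t\hat{G}(w^t, A^t)]_{i,i}\,\partial A_{i,i}/\partial w^t_{i,j}$. The NumPy sketch below implements the resulting update under the same assumptions as before (diagonal $A$, $b_w = \mathrm{sign}(w)$); the closed-form scale $A_{i,i} = \sum_j |w_{i,j}| / \sum_j w_{i,j}^2$ and its analytic derivative used in the demo are assumptions made only to keep the example self-contained, since the section obtains $A$ by gradient descent (Eq. 3.123) rather than in closed form.

```python
import numpy as np

def recurrent_weight_update(w, w_vanilla, a, dA_dw, eta2, lam):
    """Apply the coupled correction of Eq. 3.125 (sketch).

    w         : (I, J) latent weights at step t
    w_vanilla : (I, J) vanilla update w^{t+1} that ignores the coupling
    a         : (I,)   diagonal of A^t (one scale per output channel)
    dA_dw     : (I, J) entries dA_{i,i}/dw^t_{i,j}
    """
    b_w = np.sign(w)
    G_hat = (a[:, None] * w - b_w).T           # Ghat(w, A) = (A w - b_w)^T, (J, I)
    wG = w @ G_hat                             # w Ghat, (I, I)
    # Only dA_{i,i}/dw_{i,j} is nonzero, so the trace reduces to a single entry:
    # Tr(w Ghat dA/dw_{i,j}) = [w Ghat]_{i,i} * dA_{i,i}/dw_{i,j}.
    correction = np.diag(wG)[:, None] * dA_dw  # (I, J)
    return w_vanilla - eta2 * lam * correction

# Demo with an assumed closed-form scale a_i = ||w_i||_1 / ||w_i||_2^2 and its
# analytic derivative; this choice is for illustration only.
rng = np.random.default_rng(0)
I, J = 4, 16
w = rng.standard_normal((I, J))
l1 = np.abs(w).sum(axis=1)
l2sq = (w ** 2).sum(axis=1)
a = l1 / l2sq
dA_dw = (np.sign(w) * l2sq[:, None] - 2.0 * w * l1[:, None]) / (l2sq ** 2)[:, None]

w_vanilla = w - 1e-2 * rng.standard_normal((I, J))  # stand-in for w^{t+1}
w_hat = recurrent_weight_update(w, w_vanilla, a, dA_dw, eta2=1e-2, lam=1e-4)
print("mean |coupling correction|:", np.abs(w_hat - w_vanilla).mean())
```

Any other model of how $A$ depends on $w$ only changes how dA_dw is computed; the structure of the correction term itself stays the same.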